{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d006b2ea-9dfe-49c7-88a9-a5a0775185fd",
   "metadata": {},
   "source": [
    "### Building a Chatbot Interface, with Text or Voice Input, Multi-LLM support, and Memory Persistence"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eeb20b3e",
   "metadata": {},
   "source": [
    "In this tutorial, we’ll use Gradio to build a simple chatbot prototype with a user-friendly interface. The chatbot will support multiple language models, allowing the user to switch models at any point during the conversation. It will also offer optional memory persistence, where the chat history is stored and forwarded to the selected model — which allows shared memory across models, even when switching mid-chat.\n",
    "\n",
    "In this project, we'll use OpenAI's API, Anthropic's Claude, and Meta's LLaMA, which runs locally via an Ollama server. Additionally, we'll use Python’s speech_recognition module to convert speech to text.\n",
    "\n",
    "It's worth noting that some APIs — such as OpenAI's — now support direct audio input, so integrating speech capabilities can also be done end-to-end without a separate transcription module."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "a07e7793-b8f5-44f4-aded-5562f633271a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import requests\n",
    "from dotenv import load_dotenv\n",
    "from openai import OpenAI\n",
    "import anthropic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "a0a343b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Speech recording and recognition libraries\n",
    "import speech_recognition as sr\n",
    "import sounddevice as sd\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "d7693eda",
   "metadata": {},
   "outputs": [],
   "source": [
    "# GUI prototyping\n",
    "import gradio as gr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "41ffc0e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "buffer = [] # For temporarily holding sound recording\n",
    "\n",
    "#  Helper function for handling voice recording\n",
    "def callback(indata, frames, time, status):\n",
    "    buffer.append(indata.copy())\n",
    "\n",
    "stream = sd.InputStream(callback=callback, samplerate=16000, channels=1, dtype='int16')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "e9a79075",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Function for handling recording data and status\n",
    "def toggle_recording(state):\n",
    "    global stream, buffer\n",
    "    print('state', state)\n",
    "\n",
    "    if not state:\n",
    "        buffer.clear()\n",
    "        stream.start()\n",
    "        return gr.update(value=\"Stop Recording\"), 'Recording...', not state\n",
    "    else:\n",
    "        stream.stop()\n",
    "        audio = np.concatenate(buffer, axis=0)\n",
    "        text = transcribe(audio)\n",
    "        return gr.update(value=\"Start Recording\"), text, not state\n",
    "\n",
    "# Functio that converts speech to text via Google's voice recognition module\n",
    "def transcribe(recording, sample_rate=16000):\n",
    "    r = sr.Recognizer()\n",
    "\n",
    "    # Convert NumPy array to AudioData\n",
    "    audio_data = sr.AudioData(\n",
    "    recording.tobytes(),              # Raw byte data\n",
    "    sample_rate,                     # Sample rate\n",
    "        2                                # Sample width in bytes (16-bit = 2 bytes)\n",
    "    )\n",
    "\n",
    "    text = r.recognize_google(audio_data)\n",
    "    print(\"You said:\", text)\n",
    "    return text"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dcfb0190",
   "metadata": {},
   "source": [
    "### LLM & API set-up"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "59416453",
   "metadata": {},
   "source": [
    "##### Load API keys from .env"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "b638b822",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "OpenAI API Key exists and begins sk-proj-\n",
      "Anthropic API Key exists and begins sk-ant-\n",
      "Google API Key not set\n"
     ]
    }
   ],
   "source": [
    "# Load environment variables in a file called .env\n",
    "# Print the key prefixes to help with any debugging\n",
    "\n",
    "load_dotenv(override=True)\n",
    "openai_api_key = os.getenv('OPENAI_API_KEY')\n",
    "anthropic_api_key = os.getenv('ANTHROPIC_API_KEY')\n",
    "google_api_key = os.getenv('GOOGLE_API_KEY')\n",
    "\n",
    "if openai_api_key:\n",
    "    print(f\"OpenAI API Key exists and begins {openai_api_key[:8]}\")\n",
    "else:\n",
    "    print(\"OpenAI API Key not set\")\n",
    "    \n",
    "if anthropic_api_key:\n",
    "    print(f\"Anthropic API Key exists and begins {anthropic_api_key[:7]}\")\n",
    "else:\n",
    "    print(\"Anthropic API Key not set\")\n",
    "\n",
    "if google_api_key:\n",
    "    print(f\"Google API Key exists and begins {google_api_key[:8]}\")\n",
    "else:\n",
    "    print(\"Google API Key not set\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9e6ae162",
   "metadata": {},
   "source": [
    "### Class for handling API calls and routing requests to the selected models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "268ea65d",
   "metadata": {},
   "outputs": [],
   "source": [
    "class LLMHandler:\n",
    "    def __init__(self, system_message: str = '', ollama_api:str='http://localhost:11434/api/chat'):\n",
    "        # Default system message if none provided\n",
    "        self.system_message = system_message if system_message else \"You are a helpful assistant. Always reply in Markdown\"\n",
    "        self.message_history = []\n",
    "\n",
    "        # Initialize LLM clients\n",
    "        self.openai = OpenAI()\n",
    "        self.claude = anthropic.Anthropic()\n",
    "        self.OLLAMA_API = ollama_api\n",
    "        self.OLLAMA_HEADERS = {\"Content-Type\": \"application/json\"}\n",
    "\n",
    "    def llm_call(self, model: str = 'gpt-4o-mini', prompt: str = '', memory_persistence=True):\n",
    "        if not model:\n",
    "            return 'No model specified'\n",
    "\n",
    "        # Use full message template with system prompt if no prior history\n",
    "        message = self.get_message_template(prompt, initial=True) if (\n",
    "            not self.message_history and not 'claude' in model\n",
    "             ) else self.get_message_template(prompt)\n",
    "\n",
    "        # Handle memory persistence\n",
    "        if memory_persistence:\n",
    "            self.message_history.extend(message)\n",
    "        else:\n",
    "            self.message_history = message\n",
    "\n",
    "        # Model-specific dispatch\n",
    "        try:\n",
    "            if 'gpt' in model:\n",
    "                response = self.call_openai(model=model)\n",
    "            elif 'claude' in model:\n",
    "                response = self.call_claude(model=model)\n",
    "            elif 'llama' in model:\n",
    "                response = self.call_ollama(model=model)\n",
    "            else:\n",
    "                response = f'{model.title()} is not supported or not a valid model name.'\n",
    "        except Exception as e:\n",
    "            response = f'Failed to retrieve response. Reason: {e}'\n",
    "\n",
    "        # Save assistant's reply to history if memory is enabled\n",
    "        if memory_persistence:\n",
    "            self.message_history.append({\n",
    "                \"role\": \"assistant\",\n",
    "                \"content\": response\n",
    "            })\n",
    "\n",
    "        return response\n",
    "\n",
    "    def get_message_template(self, prompt: str = '', initial=False):\n",
    "        # Returns a message template with or without system prompt\n",
    "        initial_template = [\n",
    "            {\"role\": \"system\", \"content\": self.system_message},\n",
    "            {\"role\": \"user\", \"content\": prompt}\n",
    "        ]\n",
    "        general_template = [\n",
    "            {\"role\": \"user\", \"content\": prompt}\n",
    "        ]\n",
    "        return initial_template if initial else general_template\n",
    "\n",
    "    def call_openai(self, model: str = 'gpt-4o-mini'):\n",
    "        # Sends chat completion request to OpenAI API\n",
    "        completion = self.openai.chat.completions.create(\n",
    "            model=model,\n",
    "            messages=self.message_history,\n",
    "        )\n",
    "        response = completion.choices[0].message.content\n",
    "        return response\n",
    "\n",
    "    def call_ollama(self, model: str = \"llama3.2\"):\n",
    "\n",
    "        payload = {\n",
    "            \"model\": model,\n",
    "            \"messages\": self.message_history,\n",
    "            \"stream\": False\n",
    "        }\n",
    "\n",
    "        response = requests.post(url=self.OLLAMA_API, headers=self.OLLAMA_HEADERS, json=payload)\n",
    "        return response.json()[\"message\"][\"content\"]\n",
    "\n",
    "    def call_claude(self, model: str = \"claude-3-haiku-20240307\"):\n",
    "        # Sends chat request to Anthropic Claude API\n",
    "        message = self.claude.messages.create(\n",
    "            model=model,\n",
    "            system=self.system_message,\n",
    "            messages=self.message_history,\n",
    "            max_tokens=500\n",
    "        )\n",
    "        response = message.content[0].text\n",
    "        return response\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "632e618b",
   "metadata": {},
   "outputs": [],
   "source": [
    "llm_handler = LLMHandler()\n",
    "\n",
    "# Function to handle user prompts received by the interface\n",
    "def llm_call(model, prompt, memory_persistence):\n",
    "    response = llm_handler.llm_call(model=model, prompt=prompt, memory_persistence=memory_persistence)\n",
    "    return response, ''\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "e19228f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Specify available model names for the dropdown component\n",
    "AVAILABLE_MODELS = [\"gpt-4\", \"gpt-3.5\", \"claude-3-haiku-20240307\", \"llama3.2\", \"gpt-4o-mini\"]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "f65f43ff",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "* Running on local URL:  http://127.0.0.1:7868\n",
      "* To create a public link, set `share=True` in `launch()`.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div><iframe src=\"http://127.0.0.1:7868/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": []
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\n",
    "with gr.Blocks() as demo:\n",
    "    state = gr.State(False) # Recording state (on/off)\n",
    "    with gr.Row():\n",
    "        \n",
    "        with gr.Column():\n",
    "            out = gr.Markdown(label='Message history')\n",
    "            with gr.Row():\n",
    "                memory = gr.Checkbox(label='Toggle memory', value=True) # Handle memory status (on/off) btn\n",
    "                model_choice = gr.Dropdown(label='Model', choices=AVAILABLE_MODELS, interactive=True) # Model selection dropdown\n",
    "            query_box = gr.Textbox(label='ChatBox', placeholder=\"Your message\")\n",
    "            record_btn = gr.Button(value='Record voice message') # Start/stop recording btn\n",
    "            send_btn = gr.Button(\"Send\") # Send prompt btn\n",
    "      \n",
    "            \n",
    "    \n",
    "    record_btn.click(fn=toggle_recording, inputs=state, outputs=[record_btn, query_box, state])\n",
    "    send_btn.click(fn=llm_call, inputs=[model_choice, query_box, memory], outputs=[out, query_box])\n",
    "    \n",
    "\n",
    "demo.launch()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "general_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
